import torch
import esm
import os
import argparse
import biotite.structure.io as bsio
from tqdm import tqdm
from utils import read_fasta
from transformers import AutoTokenizer, EsmForProteinFolding

from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37

def convert_outputs_to_pdb(outputs):
    final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()}
    final_atom_positions = final_atom_positions.cpu().numpy()
    final_atom_mask = outputs["atom37_atom_exists"]
    pdbs = []
    for i in range(outputs["aatype"].shape[0]):
        aa = outputs["aatype"][i]
        pred_pos = final_atom_positions[i]
        mask = final_atom_mask[i]
        resid = outputs["residue_index"][i] + 1
        pred = OFProtein(
            aatype=aa,
            atom_positions=pred_pos,
            atom_mask=mask,
            residue_index=resid,
            b_factors=outputs["plddt"][i],
            chain_index=outputs["chain_index"][i] if "chain_index" in outputs else None,
        )
        pdbs.append(to_pdb(pred))
    return pdbs

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--sequence", type=str, default=None)
    parser.add_argument("--dataset_input", type=str, default=None)
    parser.add_argument("--fasta_dir", type=str, default=None)
    parser.add_argument("--out_dir", type=str)
    parser.add_argument("--out_file", type=str, default="result.pdb")
    parser.add_argument("--chunk_size", type=int)
    args = parser.parse_args()
    
    # model = esm.pretrained.esmfold_v1()
    # model = model.eval().cuda()
    
    tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
    model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)

    model = model.cuda()

    # Optionally, uncomment to set a chunk size for axial attention. This can help reduce memory.
    # Lower sizes will have lower memory requirements at the cost of increased speed.
    if args.chunk_size is not None:
        model.trunk.set_chunk_size(args.chunk_size)
    
    if args.dataset_input is not None:
        proteins = sorted(os.listdir(args.dataset_input))
        bar = tqdm(proteins)
        for p in bar:
            out_file = os.path.join(args.dataset_input, p, f"{p}.ef.pdb")
            if os.path.exists(out_file):
                continue
            bar.set_description(p)
            sequence = read_fasta(os.path.join(args.dataset_input, p, f"{p}.fasta"), "seq")
            # Multimer prediction can be done with chains separated by ':'
            with torch.no_grad():
                output = model.infer_pdb(sequence)

            with open(out_file, "w") as f:
                f.write(output)

            struct = bsio.load_structure(out_file, extra_fields=["b_factor"])
            print(p, struct.b_factor.mean())  # this will be the pLDDT
    
    if args.fasta_dir is not None:
        os.makedirs(args.out_dir, exist_ok=True)
        proteins = sorted(os.listdir(args.fasta_dir))
        bar = tqdm(proteins)
        for p in bar:
            name = p[:-6]
            bar.set_description(name)
            out_file = os.path.join(args.out_dir, f"{name}.ef.pdb")
            if os.path.exists(out_file):
                continue
            bar.set_description(p)
            sequence = read_fasta(os.path.join(args.fasta_dir, p), "seq")
            tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
            # Multimer prediction can be done with chains separated by ':'

            with torch.no_grad():
                output = model(tokenized_input)

            pdb = convert_outputs_to_pdb(output)
            with open(out_file, "w") as f:
                f.write("\n".join(pdb))

            struct = bsio.load_structure(out_file, extra_fields=["b_factor"])
            print(p, struct.b_factor.mean())
    else:
        sequence = args.sequence
        # Multimer prediction can be done with chains separated by ':'

        with torch.no_grad():
            output = model.infer_pdb(sequence)

        with open(args.out_file, "w") as f:
            f.write(output)

        struct = bsio.load_structure(args.out_file, extra_fields=["b_factor"])
        print(struct.b_factor.mean())